# Global chunk options and packages for the Bayesian modelling exercises.
knitr::opts_chunk$set(fig.width = 12, fig.height = 8, warning = FALSE, message = FALSE)
# Attach all required packages in one pass (order preserved).
for (pkg in c("R2jags", "MASS", "ggplot2", "mcmcplots", "MCMCvis", "geoR")) {
  library(pkg, character.only = TRUE)
}
# Loads the swim-time matrix Y (rows = swimmers, columns = weeks).
load("swim_time.RData")
ns <- nrow(Y)  # number of swimmers
nt <- ncol(Y)  # number of observed weeks
# JAGS model: hierarchical linear growth curves for swimmers' times.
# Y[i, j] is the recorded time of swimmer i in week j.
JAGS_swimmer_model = function(){
# Likelihood: swimmer-specific intercept a[i] and weekly slope b[i],
# with a common observation precision tau_e.
for (i in 1:ns) {
for (j in 1:nt) {
Y[i, j] ~ dnorm(mean[i, j], tau_e)
mean[i, j] <- a[i] + b[i] *j
}
}
# Swimmer-level priors plus posterior prediction.
for (i in 1:ns){
# Intercepts centred at 22 (prior guess for a typical time -- units
# presumably seconds; confirm against the loaded data).
a[i]~ dnorm(22, tau_a)
b[i]~ dnorm(0, tau_b)
# Posterior predictive draw at week 7, i.e. one week beyond the data
# (assumes nt == 6 -- TODO confirm from swim_time.RData).
Y_pred[i] ~ dnorm(a[i] + b[i] * 7, tau_e)
# Indicator that swimmer i has the fastest (minimum) predicted time;
# its posterior mean estimates Pr(Y_i* = min of all Y*s | Y).
z[i] <- (Y_pred[i] == min(Y_pred))
}
# Diffuse Gamma priors on all precisions; sigma2_* expose the variances.
tau_a~ dgamma(0.1, 0.1)
tau_b ~ dgamma(0.1, 0.1)
tau_e ~ dgamma(0.1, 0.1)
sigma2_a <- 1/tau_a
sigma2_b <- 1/tau_b
sigma2_e <- 1/tau_e
}
# Fit the swimmer growth-curve model with five over-dispersed chains.
# The five (intercept, slope) starting pairs are generated with Map instead
# of five hand-written lists.
fit_swimmer_model <- jags(
  data = list(Y = Y, ns = ns, nt = nt),
  inits = Map(function(a0, b0) list(a = rep(a0, 4), b = rep(b0, 4)),
              c(20, 30, 25, 10, 15),
              c(0, 1, -1, 0, 0)),
  parameters.to.save = c("a", "b", "sigma2_a", "sigma2_b", "sigma2_e",
                         "Y_pred", "z"),
  n.chains = 5,
  n.iter = 10000,
  n.burnin = 1000,
  model.file = JAGS_swimmer_model
)
## Compiling model graph
## Resolving undeclared variables
## Allocating nodes
## Graph information:
## Observed stochastic nodes: 24
## Unobserved stochastic nodes: 15
## Total graph size: 158
##
## Initializing model
# Convert the fit to coda mcmc objects, inspect trace/density plots, and
# check convergence with univariate Gelman-Rubin PSRFs.
chains <- as.mcmc(fit_swimmer_model)
params_swim <- c("a", "b", "sigma2_a", "sigma2_b", "sigma2_e", "Y_pred")
MCMCtrace(chains, params = params_swim, pdf = FALSE)
gelman.diag(chains, multivariate = FALSE)
## Potential scale reduction factors:
##
## Point est. Upper C.I.
## Y_pred[1] 1.000 1.00
## Y_pred[2] 1.001 1.00
## Y_pred[3] 1.000 1.00
## Y_pred[4] 1.000 1.00
## a[1] 1.001 1.00
## a[2] 1.000 1.00
## a[3] 0.999 1.00
## a[4] 1.000 1.00
## b[1] 1.001 1.00
## b[2] 0.999 1.00
## b[3] 0.999 1.00
## b[4] 1.000 1.00
## deviance 1.000 1.00
## sigma2_a 1.073 1.08
## sigma2_b 1.112 1.11
## sigma2_e 1.001 1.00
## z[1] 0.999 1.00
## z[2] 1.292 1.41
## z[3] 1.000 1.00
## z[4] 1.123 1.13
To obtain the posterior predictive distribution (PPD), we draw Y_pred within the JAGS model, so the density plots shown for Y_pred are the posterior predictive distributions. We also report the posterior mean and 95% credible interval for each Y_pred.
# Posterior summaries (means, SDs, MC errors, quantiles) for all monitored
# nodes; the Y_pred rows give the posterior predictive mean and 95% CI.
summary(chains)
##
## Iterations = 1:8992
## Thinning interval = 9
## Number of chains = 5
## Sample size per chain = 1000
##
## 1. Empirical mean and standard deviation for each variable,
## plus standard error of the mean:
##
## Mean SD Naive SE Time-series SE
## Y_pred[1] 22.63518 0.21028 0.0029738 0.0028316
## Y_pred[2] 23.58841 0.20996 0.0029693 0.0030091
## Y_pred[3] 22.90530 0.20769 0.0029371 0.0029060
## Y_pred[4] 23.38292 0.21124 0.0029874 0.0030618
## a[1] 23.23018 0.14169 0.0020038 0.0020042
## a[2] 23.10697 0.14217 0.0020106 0.0020321
## a[3] 22.62481 0.14291 0.0020210 0.0020217
## a[4] 23.73654 0.14061 0.0019885 0.0019795
## b[1] -0.08550 0.03594 0.0005083 0.0005084
## b[2] 0.06848 0.03628 0.0005130 0.0005186
## b[3] 0.03990 0.03604 0.0005097 0.0005028
## b[4] -0.05035 0.03606 0.0005100 0.0005050
## deviance -33.66875 7.94349 0.1123379 0.1123504
## sigma2_a 2.84818 5.35777 0.0757703 0.0769834
## sigma2_b 0.10153 0.18209 0.0025752 0.0026037
## sigma2_e 0.02333 0.01569 0.0002218 0.0002218
## z[1] 0.81880 0.38522 0.0054479 0.0054292
## z[2] 0.00100 0.03161 0.0004470 0.0004469
## z[3] 0.17720 0.38188 0.0054005 0.0052896
## z[4] 0.00300 0.05470 0.0007735 0.0009765
##
## 2. Quantiles for each variable:
##
## 2.5% 25% 50% 75% 97.5%
## Y_pred[1] 22.221198 22.50165 22.63237 22.77192 23.04596
## Y_pred[2] 23.186087 23.44890 23.58941 23.72560 23.99414
## Y_pred[3] 22.498039 22.76842 22.90894 23.03995 23.31065
## Y_pred[4] 22.974650 23.24814 23.38420 23.51291 23.81232
## a[1] 22.950542 23.14163 23.23085 23.32085 23.50107
## a[2] 22.827278 23.01511 23.10855 23.20021 23.38726
## a[3] 22.344728 22.53499 22.62379 22.71632 22.90611
## a[4] 23.455967 23.65011 23.73883 23.82785 24.00768
## b[1] -0.154192 -0.10921 -0.08582 -0.06265 -0.01477
## b[2] -0.004245 0.04547 0.06807 0.09164 0.13977
## b[3] -0.031192 0.01642 0.03996 0.06357 0.10981
## b[4] -0.119985 -0.07388 -0.05099 -0.02774 0.02113
## deviance -46.284975 -39.20200 -34.53935 -29.06704 -16.77206
## sigma2_a 0.535146 1.12699 1.80325 3.07215 10.69745
## sigma2_b 0.019693 0.03968 0.06251 0.10690 0.42356
## sigma2_e 0.011234 0.01670 0.02095 0.02707 0.04531
## z[1] 0.000000 1.00000 1.00000 1.00000 1.00000
## z[2] 0.000000 0.00000 0.00000 0.00000 0.00000
## z[3] 0.000000 0.00000 0.00000 0.00000 1.00000
## z[4] 0.000000 0.00000 0.00000 0.00000 0.00000
In the JAGS model, we define z[i] as the indicator that individual i has the fastest posterior predictive time. The posterior mean of z[i] therefore estimates the probability \(Pr(Y_i^*=min(Y_1^*,...,Y_4^*)|\boldsymbol{Y})\). Based on the posterior means of z = (0.8188, 0.0010, 0.1772, 0.0030) reported in the summary above, we would recommend swimmer 1.
Linear regression.
#import data
X = UScrime[,1:15]
Y = UScrime[,16]
n = nrow(UScrime)
p = ncol(UScrime) - 1
df = list()
df$X = X
df$Y = Y
df$n = n
df$p = p
# JAGS model: Bayesian linear regression with vague ("flat") priors on the
# coefficients and a diffuse Gamma prior on the error precision.
JAGS_BLR_flat = function(){
# Likelihood
for(i in 1:n){
Y[i] ~ dnorm(mu[i],inv_sigma2)
# inprod(X[i,], beta) == X[i,1]*beta[1] + ... + X[i,p]*beta[p]
mu[i] <- beta_0 + inprod(X[i,],beta)
}
# Prior for beta
for(j in 1:p){
beta[j] ~ dnorm(0,0.0001)
# non-informative priors: precision 1e-4, i.e. variance 1e4
}
# Prior for intercept
beta_0 ~ dnorm(0, 0.0001)
# Prior for the inverse variance
inv_sigma2 ~ dgamma(0.0001, 0.0001)
sigma2 <- 1.0/inv_sigma2
}
# Fit the flat-prior regression with five chains. Each chain starts from a
# random beta vector and a dispersed (beta_0, inv_sigma2) pair; Map builds
# the five init lists, drawing rnorm(p) once per chain in the same order as
# the original hand-written lists.
fit_JAGS_flat <- jags(
  data = df,
  inits = Map(function(b0, prec0) list(beta = rnorm(p),
                                       beta_0 = b0,
                                       inv_sigma2 = prec0),
              c(0, 1, 2, 10, 20),
              c(1, 2, 2, 5, 1)),
  parameters.to.save = c("beta_0", "beta", "sigma2"),
  n.chains = 5,
  n.iter = 10000,
  n.burnin = 1000,
  model.file = JAGS_BLR_flat
)
## Compiling model graph
## Resolving undeclared variables
## Allocating nodes
## Graph information:
## Observed stochastic nodes: 47
## Unobserved stochastic nodes: 17
## Total graph size: 950
##
## Initializing model
# Convergence diagnostics for the flat-prior fit: trace/density plots and
# Gelman-Rubin PSRFs (all near 1 in the printed output, indicating
# convergence).
chains_2a = as.mcmc(fit_JAGS_flat)
plot(chains_2a)
gelman.diag(chains_2a)
## Potential scale reduction factors:
##
## Point est. Upper C.I.
## beta[1] 1.000 1
## beta[2] 0.999 1
## beta[3] 1.000 1
## beta[4] 1.000 1
## beta[5] 1.000 1
## beta[6] 0.999 1
## beta[7] 1.000 1
## beta[8] 1.001 1
## beta[9] 1.000 1
## beta[10] 0.999 1
## beta[11] 1.000 1
## beta[12] 1.000 1
## beta[13] 1.000 1
## beta[14] 1.000 1
## beta[15] 1.000 1
## beta_0 1.001 1
## deviance 1.000 1
## sigma2 1.000 1
##
## Multivariate psrf
##
## 1.01
##output 95% credible interval
# Posterior summaries; the 2.5% and 97.5% columns give 95% credible
# intervals for each coefficient.
summary(chains_2a)
##
## Iterations = 1:8992
## Thinning interval = 9
## Number of chains = 5
## Sample size per chain = 1000
##
## 1. Empirical mean and standard deviation for each variable,
## plus standard error of the mean:
##
## Mean SD Naive SE Time-series SE
## beta[1] 6.2769 5.713 0.08079 0.08050
## beta[2] -4.6720 88.397 1.25012 1.23507
## beta[3] 13.3712 8.189 0.11580 0.11861
## beta[4] 24.9122 13.418 0.18976 0.18974
## beta[5] -12.8245 14.783 0.20906 0.20530
## beta[6] 0.9340 1.750 0.02475 0.02448
## beta[7] -4.3725 2.089 0.02954 0.02985
## beta[8] -2.3473 1.699 0.02403 0.02485
## beta[9] -0.1931 0.739 0.01045 0.01042
## beta[10] 0.4376 5.046 0.07136 0.06946
## beta[11] 7.2377 10.392 0.14696 0.14582
## beta[12] 0.1785 1.344 0.01901 0.01845
## beta[13] 5.5855 2.821 0.03990 0.03992
## beta[14] -14.8201 181.978 2.57356 2.57433
## beta[15] -0.8186 7.566 0.10699 0.10615
## beta_0 -24.2766 213.411 3.01809 3.01893
## deviance 662.4300 6.534 0.09240 0.09008
## sigma2 79914.8145 21161.755 299.27241 299.37872
##
## 2. Quantiles for each variable:
##
## 2.5% 25% 50% 75% 97.5%
## beta[1] -4.9182 2.5486 6.3255 9.9486 1.756e+01
## beta[2] -172.0069 -64.8561 -5.1441 55.2904 1.708e+02
## beta[3] -2.7158 8.0150 13.4763 18.8642 2.945e+01
## beta[4] -1.9181 16.3079 25.0483 33.6330 5.091e+01
## beta[5] -41.6939 -22.6782 -12.8314 -2.9717 1.690e+01
## beta[6] -2.4874 -0.2190 0.9350 2.1132 4.352e+00
## beta[7] -8.4158 -5.7711 -4.3975 -2.9619 -3.048e-01
## beta[8] -5.7585 -3.4747 -2.3252 -1.2196 8.994e-01
## beta[9] -1.6545 -0.6773 -0.2015 0.3003 1.286e+00
## beta[10] -9.5560 -2.9348 0.4491 3.7481 1.036e+01
## beta[11] -13.4632 0.4081 7.3432 13.9251 2.742e+01
## beta[12] -2.4923 -0.7225 0.1744 1.0678 2.850e+00
## beta[13] -0.1074 3.6987 5.5768 7.4331 1.106e+01
## beta[14] -205.6813 -77.2050 -8.9474 54.7239 1.839e+02
## beta[15] -15.7806 -5.8559 -0.8993 4.3155 1.421e+01
## beta_0 -215.3003 -87.7920 -18.7683 48.1790 1.781e+02
## deviance 652.1919 657.7971 661.7272 666.3166 6.765e+02
## sigma2 48251.8838 64838.1451 76522.0964 91354.9055 1.304e+05
Cross validation
#split data into training and test set
#' Split a regression data list into training and test pieces.
#'
#' @param df list with X (n x p predictors), Y (length-n response),
#'   n (row count) and p (predictor count).
#' @param train_test_ratio ratio of training to test set size (1 = 50/50).
#' @param random if TRUE, sample training rows without replacement;
#'   otherwise the first n_train rows train and the remainder test.
#' @return list(df_t, Y_test): df_t is a JAGS-ready data list with the test
#'   responses held out; Y_test holds the true test responses for later
#'   comparison against posterior predictions.
split_data <- function(df, train_test_ratio = 1, random = TRUE) {
  n_train <- floor(df$n * train_test_ratio / (1 + train_test_ratio))
  n_test <- df$n - n_train
  if (random) {
    # BUG FIX: index df's own rows (df$n), not the global variable `n`,
    # which only happened to hold the right value in this script.
    train_idx <- sample(seq_len(df$n), n_train, replace = FALSE)
    test_idx <- setdiff(seq_len(df$n), train_idx)
  } else {
    train_idx <- seq_len(n_train)
    # Explicit grouping: `n_train + 1:n_test` relied on `:` binding
    # tighter than `+`.
    test_idx <- n_train + seq_len(n_test)
  }
  df_t <- list(
    Y_train = df$Y[train_idx],
    X_train = df$X[train_idx, , drop = FALSE],
    X_test = df$X[test_idx, , drop = FALSE],
    n_train = n_train,
    n_test = n_test,
    p = df$p
  )
  list(df_t = df_t, Y_test = df$Y[test_idx])
}
# Deterministic split (first half trains, second half tests) so the
# comparison below is reproducible without setting a seed.
pred = split_data(df, random = FALSE)
##define a predictive JAGS
# JAGS model: flat-prior regression fitted on the training rows, with the
# test responses declared as unobserved nodes so JAGS samples them from the
# posterior predictive distribution.
JAGS_BLR_flat_pred = function(){
# Likelihood
for(i in 1:n_train){
Y_train[i] ~ dnorm(mu_train[i],inv_sigma2)
mu_train[i] <- beta_0 + inprod(X_train[i,],beta)
# same as beta_0 + X[i,1]*beta[1] + ... + X[i,p]*beta[p]
}
# Prior for beta
for(j in 1:p){
beta[j] ~ dnorm(0,0.0001)
# non-informative priors: precision 1e-4, i.e. variance 1e4
}
# Prior for intercept
beta_0 ~ dnorm(0, 0.0001)
# Prior for the inverse variance
inv_sigma2 ~ dgamma(0.0001, 0.0001)
sigma2 <- 1.0/inv_sigma2
#prediction
# Predictions: Y_test is not supplied as data, so each draw is a posterior
# predictive sample at the held-out X_test rows.
for(i in 1:n_test){
Y_test[i] ~ dnorm(mu_test[i],inv_sigma2)
mu_test[i] <- beta_0 + inprod(X_test[i,],beta)
}
}
# Fit the predictive flat-prior model on the training split; also monitor
# Y_test to collect posterior predictive draws. Inits are built with Map,
# drawing rnorm(p) once per chain in the original order.
fit_JAGS_flat_pred <- jags(
  data = pred$df_t,
  inits = Map(function(b0, prec0) list(beta = rnorm(p),
                                       beta_0 = b0,
                                       inv_sigma2 = prec0),
              c(0, 1, 2, 10, 20),
              c(1, 2, 2, 5, 1)),
  parameters.to.save = c("beta_0", "beta", "sigma2", "Y_test"),
  n.chains = 5,
  n.iter = 10000,
  n.burnin = 1000,
  model.file = JAGS_BLR_flat_pred
)
## Compiling model graph
## Resolving undeclared variables
## Allocating nodes
## Graph information:
## Observed stochastic nodes: 23
## Unobserved stochastic nodes: 41
## Total graph size: 951
##
## Initializing model
# Summarise the predictive fit and compare held-out responses with their
# posterior medians.
chains_2b <- as.mcmc(fit_JAGS_flat_pred)
##plot the predictive values
result <- summary(chains_2b)
q <- result$quantiles
# FIX: as.data.frame(cbind(x)) was a redundant double conversion of a
# single matrix; coerce the quantile matrix directly.
dtfr <- as.data.frame(result$quantiles)
##compare: observed test responses ...
pred$Y_test
## [1] 968 523 1993 342 1216 1043 696 373 754 1072 923 653 1272 831
## [15] 566 826 1151 880 542 823 1030 455 508 849
# ... versus their posterior predictive medians.
fit_JAGS_flat_pred$BUGSoutput$median$Y_test
## [1] 683.7445 919.9540 1700.4795 187.7366 864.4677 2089.7941 785.0522
## [8] 251.8214 1288.7909 651.3395 749.8402 1275.7072 756.4813 568.5020
## [15] 744.4670 990.3948 1397.8558 538.7887 503.3671 891.7807 572.7658
## [22] 527.6371 1337.7897 576.2757
# JAGS model: linear regression with a spike-and-slab prior on each slope.
# gamma[j] = 1 puts beta[j] in the "slab" (precision 0.01, variance 100);
# gamma[j] = 0 shrinks it into the "spike" (precision 1000, variance 0.001).
JAGS_BLR_SpikeSlab = function(){
# Likelihood
for(i in 1:n){
Y[i] ~ dnorm(mu[i],inv_sigma2)
mu[i] <- beta_0 + inprod(X[i,],beta)
# same as beta_0 + X[i,1]*beta[1] + ... + X[i,p]*beta[p]
}
#Prior for beta_j
for(j in 1:p){
beta[j] ~ dnorm(0,inv_tau2[j])
# Two-point mixture over the prior precision, switched by gamma[j].
inv_tau2[j] <- (1-gamma[j])*1000+gamma[j]*0.01
# Prior inclusion probability 1/2 for every predictor.
gamma[j] ~ dbern(0.5)
}
# Prior for intercept
beta_0 ~ dnorm(0, 0.0001)
# Prior for the inverse variance
inv_sigma2 ~ dgamma(0.0001, 0.0001)
sigma2 <- 1.0/inv_sigma2
}
# Fit the spike-and-slab regression.
# BUG FIX: the original call passed model.file = JAGS_BLR_flat, so the
# spike-and-slab prior was never used -- the printed graph (17 unobserved
# stochastic nodes, size 950) is identical to the flat fit, whereas the
# spike-and-slab model adds 15 gamma nodes. Point model.file at
# JAGS_BLR_SpikeSlab and also monitor gamma, the inclusion indicators whose
# posterior means are the variable-selection probabilities.
fit_JAGS_SpikeSlab <- jags(
  data = df,
  inits = Map(function(b0, prec0) list(beta = rnorm(p),
                                       beta_0 = b0,
                                       inv_sigma2 = prec0),
              c(0, 1, 2, 10, 20),
              c(1, 2, 2, 5, 1)),
  parameters.to.save = c("beta_0", "beta", "sigma2", "gamma"),
  n.chains = 5,
  n.iter = 10000,
  n.burnin = 1000,
  model.file = JAGS_BLR_SpikeSlab
)
## Compiling model graph
## Resolving undeclared variables
## Allocating nodes
## Graph information:
## Observed stochastic nodes: 47
## Unobserved stochastic nodes: 17
## Total graph size: 950
##
## Initializing model
# R2jags overview plot, then coda-based posterior summaries for the
# spike-and-slab fit.
plot(fit_JAGS_SpikeSlab)
chains_2c = as.mcmc(fit_JAGS_SpikeSlab)
summary(chains_2c)
##
## Iterations = 1:8992
## Thinning interval = 9
## Number of chains = 5
## Sample size per chain = 1000
##
## 1. Empirical mean and standard deviation for each variable,
## plus standard error of the mean:
##
## Mean SD Naive SE Time-series SE
## beta[1] 6.4076 5.593e+00 0.07910 0.07771
## beta[2] -3.9629 8.867e+01 1.25399 1.26927
## beta[3] 13.1462 8.064e+00 0.11404 0.10957
## beta[4] 24.7802 1.339e+01 0.18937 0.18891
## beta[5] -12.7592 1.472e+01 0.20820 0.21087
## beta[6] 0.9788 1.722e+00 0.02436 0.02459
## beta[7] -4.4125 2.061e+00 0.02914 0.02855
## beta[8] -2.3948 1.655e+00 0.02340 0.02340
## beta[9] -0.1881 7.436e-01 0.01052 0.01052
## beta[10] 0.4887 5.072e+00 0.07173 0.07190
## beta[11] 7.2933 1.050e+01 0.14853 0.14507
## beta[12] 0.2272 1.361e+00 0.01925 0.01947
## beta[13] 5.5454 2.809e+00 0.03973 0.03877
## beta[14] -15.3347 1.819e+02 2.57224 2.57313
## beta[15] -0.7006 7.624e+00 0.10783 0.10782
## beta_0 -25.5572 2.124e+02 3.00320 3.00398
## deviance 662.2133 6.583e+00 0.09310 0.09319
## sigma2 79194.1885 2.074e+04 293.29288 293.26671
##
## 2. Quantiles for each variable:
##
## 2.5% 25% 50% 75% 97.5%
## beta[1] -4.604e+00 2.7303 6.3836 10.1780 1.728e+01
## beta[2] -1.782e+02 -62.9237 -3.9262 55.6993 1.716e+02
## beta[3] -2.441e+00 7.8615 13.0893 18.3481 2.955e+01
## beta[4] -1.820e+00 15.9968 25.0303 33.6789 5.100e+01
## beta[5] -4.153e+01 -22.3380 -12.8812 -2.9491 1.644e+01
## beta[6] -2.488e+00 -0.1641 0.9898 2.0964 4.434e+00
## beta[7] -8.419e+00 -5.7426 -4.4307 -3.0464 -3.320e-01
## beta[8] -5.667e+00 -3.4871 -2.3867 -1.3170 8.623e-01
## beta[9] -1.677e+00 -0.6791 -0.1853 0.2968 1.270e+00
## beta[10] -9.591e+00 -2.9348 0.5176 3.8215 1.083e+01
## beta[11] -1.330e+01 0.1114 7.2177 14.3995 2.796e+01
## beta[12] -2.439e+00 -0.6832 0.2456 1.1338 2.872e+00
## beta[13] -5.786e-02 3.6825 5.5412 7.4345 1.110e+01
## beta[14] -2.098e+02 -79.4582 -9.0106 57.5367 1.783e+02
## beta[15] -1.585e+01 -5.6859 -0.7389 4.4010 1.425e+01
## beta_0 -2.107e+02 -86.9262 -19.6328 47.3713 1.716e+02
## deviance 6.520e+02 657.5292 661.4693 666.0229 6.769e+02
## sigma2 4.934e+04 64549.2147 75672.2612 89789.0624 1.286e+05
##############################################################################
##Prediction
##############################################################################
# JAGS model: spike-and-slab regression fitted on the training rows, with
# test responses left unobserved so they are drawn from the posterior
# predictive distribution.
JAGS_BLR_SpikeSlab_pred = function(){
# Likelihood
for(i in 1:n_train){
Y_train[i] ~ dnorm(mu_train[i],inv_sigma2)
mu_train[i] <- beta_0 + inprod(X_train[i,],beta)
# same as beta_0 + X[i,1]*beta[1] + ... + X[i,p]*beta[p]
}
# Prior for beta: spike (precision 1000) vs slab (precision 0.01),
# switched by the Bernoulli(0.5) inclusion indicator gamma[j].
for(j in 1:p){
beta[j] ~ dnorm(0,inv_tau2[j])
inv_tau2[j] <- (1-gamma[j])*1000+gamma[j]*0.01
gamma[j] ~ dbern(0.5)
}
# Prior for intercept
beta_0 ~ dnorm(0, 0.0001)
# Prior for the inverse variance
inv_sigma2 ~ dgamma(0.0001, 0.0001)
sigma2 <- 1.0/inv_sigma2
#prediction
# Predictions: posterior predictive draws at the held-out X_test rows.
for(i in 1:n_test){
Y_test[i] ~ dnorm(mu_test[i],inv_sigma2)
mu_test[i] <- beta_0 + inprod(X_test[i,],beta)
}
}
# Fit the predictive spike-and-slab model on the training split, monitoring
# Y_test for posterior predictive draws. Map builds the five init lists,
# drawing rnorm(p) once per chain in the original order.
fit_JAGS_SpikeSlab_pred <- jags(
  data = pred$df_t,
  inits = Map(function(b0, prec0) list(beta = rnorm(p),
                                       beta_0 = b0,
                                       inv_sigma2 = prec0),
              c(0, 1, 2, 10, 20),
              c(1, 2, 2, 5, 1)),
  parameters.to.save = c("beta_0", "beta", "sigma2", "Y_test"),
  n.chains = 5,
  n.iter = 10000,
  n.burnin = 1000,
  model.file = JAGS_BLR_SpikeSlab_pred
)
## Compiling model graph
## Resolving undeclared variables
## Allocating nodes
## Graph information:
## Observed stochastic nodes: 23
## Unobserved stochastic nodes: 56
## Total graph size: 1071
##
## Initializing model
# Side-by-side table: observed test responses (column 1) versus their
# posterior predictive medians (column 2).
cbind(pred$Y_test, fit_JAGS_SpikeSlab_pred$BUGSoutput$median$Y_test)
## [,1] [,2]
## [1,] 968 784.9404
## [2,] 523 674.0086
## [3,] 1993 1440.5658
## [4,] 342 675.5429
## [5,] 1216 838.4122
## [6,] 1043 1949.4688
## [7,] 696 749.9512
## [8,] 373 611.3505
## [9,] 754 1104.4611
## [10,] 1072 741.8762
## [11,] 923 979.0033
## [12,] 653 1221.2600
## [13,] 1272 1034.6918
## [14,] 831 712.1543
## [15,] 566 585.0030
## [16,] 826 769.8610
## [17,] 1151 1082.5412
## [18,] 880 703.0680
## [19,] 542 588.6599
## [20,] 823 863.5766
## [21,] 1030 994.1638
## [22,] 455 608.5151
## [23,] 508 1084.0068
## [24,] 849 876.3603
# Overlay observed test values (black) and posterior medians (red).
# FIX: the original plot set its y-range from the observed series only, so
# predicted values outside that range (e.g. ~1949 at observation 6) were
# clipped. Compute ylim from both series and label the plot.
y_obs <- pred$Y_test
y_hat <- fit_JAGS_SpikeSlab_pred$BUGSoutput$median$Y_test
plot(y_obs, type = "l", ylim = range(y_obs, y_hat),
     xlab = "Test observation", ylab = "Crime rate")
lines(y_hat, col = "red")
legend("topright", legend = c("observed", "predicted"),
       col = c("black", "red"), lty = 1)
# Gambia malaria data (geoR): map each record to a location index via its
# x-coordinate. (Start of the next exercise; continues past this excerpt.)
# NOTE(review): this assumes a unique x-coordinate identifies a location --
# if two villages could share an x value, match on the (x, y) pair instead.
v_loc = unique(gambia[,"x"])
v = match(gambia[,"x"],v_loc)